{
    "nbformat": 4,
    "nbformat_minor": 0,
    "metadata": {
        "accelerator": "GPU",
        "colab": {
            "name": "Speech_Commands.ipynb",
            "provenance": [],
            "collapsed_sections": [],
            "toc_visible": true
        },
        "kernelspec": {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.7.7"
        },
        "pycharm": {
            "stem_cell": {
                "cell_type": "raw",
                "source": [],
                "metadata": {
                    "collapsed": false
                }
            }
        }
    },
    "cells": [
        {
            "cell_type": "code",
            "metadata": {
                "id": "R12Yn6W1dt9t"
            },
            "source": [
                "\"\"\"\n",
                "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
                "\n",
                "Instructions for setting up Colab are as follows:\n",
                "1. Open a new Python 3 notebook.\n",
                "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
                "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
                "4. Run this cell to set up dependencies.\n",
                "\n\nNOTE: User is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use.\n",
                "\"\"\"\n",
                "# If you're using Google Colab and not running locally, run this cell.\n",
                "\n",
                "## Install dependencies\n",
                "!pip install wget\n",
                "!apt-get install sox libsndfile1 ffmpeg\n",
                "!pip install text-unidecode\n",
                "\n",
                "# ## Install NeMo\n",
                "BRANCH = 'main'\n",
                "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]\n",
                "\n",
                "## Install TorchAudio\n",
                "## NOTE: TorchAudio installation may not work in all environments, please use Google Colab for best experience\n",
                "!pip install torchaudio>=0.13.0 -f https://download.pytorch.org/whl/torch_stable.html\n",
                "\n",
                "## Grab the config we'll use in this example\n",
                "!mkdir configs"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "J6ycGIaZfSLE"
            },
            "source": [
                "# Introduction\n",
                "\n",
                "This Speech Command recognition tutorial is based on the MatchboxNet model from the paper [\"MatchboxNet: 1D Time-Channel Separable Convolutional Neural Network Architecture for Speech Commands Recognition\"](https://arxiv.org/abs/2004.08531). MatchboxNet is a modified form of the QuartzNet architecture from the paper \"[QuartzNet: Deep Automatic Speech Recognition with 1D Time-Channel Separable Convolutions](https://arxiv.org/pdf/1910.10261.pdf)\" with a modified decoder head to suit classification tasks.\n",
                "\n",
                "The notebook will follow the steps below:\n",
                "\n",
                " - Dataset preparation: Preparing Google Speech Commands dataset\n",
                "\n",
                " - Audio preprocessing (feature extraction): signal normalization, windowing, (log) spectrogram (or mel scale spectrogram, or MFCC)\n",
                "\n",
                " - Data augmentation using SpecAugment \"[SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition](https://arxiv.org/abs/1904.08779)\" to increase the number of data samples.\n",
                " \n",
                " - Develop a small Neural classification model that can be trained efficiently.\n",
                " \n",
                " - Model training on the Google Speech Commands dataset in NeMo.\n",
                " \n",
                " - Evaluation of error cases of the model by audibly hearing the samples"
            ]
        },
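        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The preprocessing and augmentation steps listed above can be illustrated with a minimal standalone sketch. It assumes `torchaudio` is installed and uses `torchaudio.transforms` rather than NeMo's own preprocessor and SpecAugment modules; the window, hop, and mask sizes are illustrative assumptions (the values MatchboxNet actually uses are in the model config printed later in this notebook).\n",
                "\n",
                "```python\n",
                "import math\n",
                "\n",
                "import torch\n",
                "import torchaudio\n",
                "\n",
                "sample_rate = 16000  # Speech Commands clips are sampled at 16 kHz\n",
                "\n",
                "# A synthetic 1-second, 440 Hz tone stands in for a real clip; shape (channels, samples)\n",
                "t = torch.arange(sample_rate) / sample_rate\n",
                "waveform = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)\n",
                "\n",
                "# Feature extraction: 25 ms windows with a 10 ms hop, mel filterbank, log scaling, then 64 MFCCs per frame\n",
                "mfcc = torchaudio.transforms.MFCC(\n",
                "    sample_rate=sample_rate,\n",
                "    n_mfcc=64,\n",
                "    melkwargs={'n_fft': 512, 'win_length': 400, 'hop_length': 160, 'n_mels': 64},\n",
                ")\n",
                "features = mfcc(waveform)  # (channels, coefficients, frames) = (1, 64, 101)\n",
                "\n",
                "# SpecAugment-style masking: zero out a random band of coefficients and a random span of frames\n",
                "freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=15)\n",
                "time_mask = torchaudio.transforms.TimeMasking(time_mask_param=25)\n",
                "augmented = time_mask(freq_mask(features))\n",
                "\n",
                "print(features.shape, augmented.shape)  # masking leaves the shape unchanged\n",
                "```\n",
                "\n",
                "Masking overwrites existing values instead of creating new clips, so the augmented features keep their shape; the added variety comes from drawing a different random mask each time a sample is seen during training."
            ]
        },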
        {
            "cell_type": "code",
            "metadata": {
                "id": "I62_LJzc-p2b"
            },
            "source": [
                "# Some utility imports\n",
                "import os\n",
                "from omegaconf import OmegaConf"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "K_M8wpkwd7d7"
            },
            "source": [
                "# This is where the Google Speech Commands directory will be placed.\n",
                "# Change this if you don't want the data to be extracted in the current directory.\n",
                "# Select the version of the dataset required as well (can be 1 or 2)\n",
                "DATASET_VER = 1\n",
                "data_dir = './google_dataset_v{0}/'.format(DATASET_VER)\n",
                "\n",
                "if DATASET_VER == 1:\n",
                "  MODEL_CONFIG = \"matchboxnet_3x1x64_v1.yaml\"\n",
                "else:\n",
                "  MODEL_CONFIG = \"matchboxnet_3x1x64_v2.yaml\"\n",
                "\n",
                "if not os.path.exists(f\"configs/{MODEL_CONFIG}\"):\n",
                "  !wget -P configs/ \"https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/matchboxnet/{MODEL_CONFIG}\""
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "tvfwv9Hjf1Uv"
            },
            "source": [
                "# Data Preparation\n",
                "\n",
                "We will be using the open-source Google Speech Commands Dataset (we will use V1 of the dataset for the tutorial but require minor changes to support the V2 dataset). These scripts below will download the dataset and convert it to a format suitable for use with NeMo."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "6VL10OXTf8ts"
            },
            "source": [
                "## Download the dataset\n",
                "\n",
                "The dataset must be prepared using the scripts provided under the `{NeMo root directory}/scripts` sub-directory. \n",
                "\n",
                "Run the following command below to download the data preparation script and execute it.\n",
                "\n",
                "**NOTE**: You should have at least 4GB of disk space available if you’ve used --data_version=1; and at least 6GB if you used --data_version=2. Also, it will take some time to download and process, so go grab a coffee.\n",
                "\n",
                "**NOTE**: You may additionally pass a `--rebalance` flag at the end of the `process_speech_commands_data.py` script to rebalance the class samples in the manifest."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "oqKe6_uLfzKU"
            },
            "source": [
                "if not os.path.exists(\"process_speech_commands_data.py\"):\n",
                "  !wget https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/scripts/dataset_processing/process_speech_commands_data.py"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "TTsxp0nZ1zqo"
            },
            "source": [
                "### Preparing the manifest file\n",
                "\n",
                "The manifest file is a simple file that has the full path to the audio file, the duration of the audio file, and the label that is assigned to that audio file. \n",
                "\n",
                "This notebook is only a demonstration, and therefore we will use the `--skip_duration` flag to speed up construction of the manifest file.\n",
                "\n",
                "**NOTE: When replicating the results of the paper, do not use this flag and prepare the manifest file with correct durations.**"
            ]
        },
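        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Because `--skip_duration` skips measuring each clip, the `duration` values written to the manifest may not reflect the real clip lengths. If you build a manifest yourself, a clip's duration is its number of sample frames divided by its sampling rate. The helper `wav_duration_seconds` below is a hypothetical sketch using Python's standard `wave` module; it is not the code used by `process_speech_commands_data.py`.\n",
                "\n",
                "```python\n",
                "import os\n",
                "import tempfile\n",
                "import wave\n",
                "\n",
                "\n",
                "def wav_duration_seconds(path):\n",
                "    # Hypothetical helper: duration = number of sample frames / sampling rate\n",
                "    with wave.open(path, 'rb') as wf:\n",
                "        return wf.getnframes() / wf.getframerate()\n",
                "\n",
                "\n",
                "# Write a silent 1-second, 16 kHz, 16-bit mono WAV file to check the helper\n",
                "path = os.path.join(tempfile.mkdtemp(), 'silence.wav')\n",
                "with wave.open(path, 'wb') as wf:\n",
                "    wf.setnchannels(1)\n",
                "    wf.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM\n",
                "    wf.setframerate(16000)\n",
                "    wf.writeframes(bytes(2 * 16000))  # 16000 frames of silence\n",
                "\n",
                "print(wav_duration_seconds(path))  # 1.0\n",
                "```"
            ]
        },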
        {
            "cell_type": "code",
            "metadata": {
                "id": "cWUtDpzKgop9"
            },
            "source": [
                "!mkdir {data_dir}\n",
                "!python process_speech_commands_data.py --data_root={data_dir} --data_version={DATASET_VER} --skip_duration --log\n",
                "print(\"Dataset ready !\")"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "eVsPFxJtg30p"
            },
            "source": [
                "## Prepare the path to manifest files"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "ytTFGVe0g9wk"
            },
            "source": [
                "dataset_path = 'google_speech_recognition_v{0}'.format(DATASET_VER)\n",
                "dataset_basedir = os.path.join(data_dir, dataset_path)\n",
                "\n",
                "train_dataset = os.path.join(dataset_basedir, 'train_manifest.json')\n",
                "val_dataset = os.path.join(dataset_basedir, 'validation_manifest.json')\n",
                "test_dataset = os.path.join(dataset_basedir, 'validation_manifest.json')"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "s0SZy9SEhOBf"
            },
            "source": [
                "## Read a few rows of the manifest file \n",
                "\n",
                "Manifest files are the data structure used by NeMo to declare a few important details about the data :\n",
                "\n",
                "1) `audio_filepath`: Refers to the path to the raw audio file <br>\n",
                "2) `command`: The class label (or speech command) of this sample <br>\n",
                "3) `duration`: The length of the audio file, in seconds."
            ]
        },
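        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Each line of a manifest is a standalone JSON object, so the file can be read line by line with Python's `json` module. The sketch below applies a hypothetical `read_manifest` helper to a small in-memory manifest with made-up paths and counts the samples per label; for the real data, you would pass `open(train_dataset)` instead.\n",
                "\n",
                "```python\n",
                "import io\n",
                "import json\n",
                "from collections import Counter\n",
                "\n",
                "\n",
                "def read_manifest(f):\n",
                "    # Hypothetical helper: parse one JSON object per non-empty line\n",
                "    return [json.loads(line) for line in f if line.strip()]\n",
                "\n",
                "\n",
                "# Two made-up manifest entries using the keys described above\n",
                "entries = [\n",
                "    {'audio_filepath': '/data/yes/clip_0.wav', 'duration': 1.0, 'command': 'yes'},\n",
                "    {'audio_filepath': '/data/no/clip_1.wav', 'duration': 1.0, 'command': 'no'},\n",
                "]\n",
                "manifest_file = io.StringIO(''.join(json.dumps(e) + '\\n' for e in entries))\n",
                "\n",
                "samples = read_manifest(manifest_file)\n",
                "print(Counter(s['command'] for s in samples))  # Counter({'yes': 1, 'no': 1})\n",
                "```"
            ]
        },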
        {
            "cell_type": "code",
            "metadata": {
                "id": "HYBidCMIhKQV"
            },
            "source": [
                "!head -n 5 {train_dataset}"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "r-pyUBedh8f4"
            },
            "source": [
                "# Training - Preparation\n",
                "\n",
                "We will be training a MatchboxNet model from the paper [\"MatchboxNet: 1D Time-Channel Separable Convolutional Neural Network Architecture for Speech Commands Recognition\"](https://arxiv.org/abs/2004.08531). The benefit of MatchboxNet over JASPER models is that they use 1D Time-Channel Separable Convolutions, which greatly reduce the number of parameters required to obtain good model accuracy.\n",
                "\n",
                "MatchboxNet models generally follow the model definition pattern QuartzNet-[BxRXC], where B is the number of blocks, R is the number of convolutional sub-blocks, and C is the number of channels in these blocks. Each sub-block contains a 1-D masked convolution, batch normalization, ReLU, and dropout.\n",
                "\n",
                "An image of QuartzNet, the base configuration of MatchboxNet models, is provided below.\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "T0sV4riijHJF"
            },
            "source": [
                "<p align=\"center\">\n",
                "  <img src=\"https://developer.nvidia.com/blog/wp-content/uploads/2020/05/quartznet-model-architecture-1-625x742.png\">\n",
                "</p>"
            ]
        },
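        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The parameter savings from time-channel separable convolutions can be checked with simple arithmetic. The hypothetical helpers below count weights (ignoring biases) for a standard 1D convolution and for a depthwise (per channel, over time) plus pointwise (1x1, across channels) pair. The channel count 64 matches the `3x1x64` config name; the kernel width of 13 is an illustrative assumption.\n",
                "\n",
                "```python\n",
                "def conv1d_params(c_in, c_out, k):\n",
                "    # Standard 1D convolution: each output channel mixes all input channels over k time steps\n",
                "    return c_in * c_out * k\n",
                "\n",
                "\n",
                "def separable_conv1d_params(c_in, c_out, k):\n",
                "    depthwise = c_in * k  # one k-wide filter per input channel, applied over time\n",
                "    pointwise = c_in * c_out  # 1x1 convolution that mixes channels\n",
                "    return depthwise + pointwise\n",
                "\n",
                "\n",
                "print(conv1d_params(64, 64, 13))  # 53248\n",
                "print(separable_conv1d_params(64, 64, 13))  # 4928\n",
                "```\n",
                "\n",
                "Here the separable version needs roughly 10.8x fewer weights. The ratio `c_out * k / (k + c_out)` approaches `k` as the channel count grows, because the pointwise term `c_in * c_out` then dominates the small depthwise term `c_in * k`."
            ]
        },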
        {
            "cell_type": "code",
            "metadata": {
                "id": "ieAPOM9thTN2"
            },
            "source": [
                "# NeMo's \"core\" package\n",
                "import nemo\n",
                "# NeMo's ASR collection - this collections contains complete ASR models and\n",
                "# building blocks (modules) for ASR\n",
                "import nemo.collections.asr as nemo_asr"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "ss9gLcDv30jI"
            },
            "source": [
                "## Model Configuration\n",
                "The MatchboxNet Model is defined in a config file which declares multiple important sections.\n",
                "\n",
                "They are:\n",
                "\n",
                "1) `model`: All arguments that will relate to the Model - preprocessors, encoder, decoder, optimizer and schedulers, datasets and any other related information\n",
                "\n",
                "2) `trainer`: Any argument to be passed to PyTorch Lightning"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "yoVAs9h1lfci"
            },
            "source": [
                "# This line will print the entire config of the MatchboxNet model\n",
                "config_path = f\"configs/{MODEL_CONFIG}\"\n",
                "config = OmegaConf.load(config_path)\n",
                "config = OmegaConf.to_container(config, resolve=True)\n",
                "config = OmegaConf.create(config)\n",
                "print(OmegaConf.to_yaml(config))"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "m2lJPR0a3qww"
            },
            "source": [
                "# Preserve some useful parameters\n",
                "labels = config.model.labels\n",
                "sample_rate = config.model.sample_rate"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "8_pmjeed78rJ"
            },
            "source": [
                "### Setting up the datasets within the config\n",
                "\n",
                "If you'll notice, there are a few config dictionaries called `train_ds`, `validation_ds` and `test_ds`. These are configurations used to setup the Dataset and DataLoaders of the corresponding config.\n",
                "\n"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "DIe6Qfs18MiQ"
            },
            "source": [
                "print(OmegaConf.to_yaml(config.model.train_ds))"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "Fb01hl868Uc3"
            },
            "source": [
                "### `???` inside configs\n",
                "\n",
                "You will often notice that some configs have `???` in place of paths. This is used as a placeholder so that the user can change the value at a later time.\n",
                "\n",
                "Let's add the paths to the manifests to the config above."
            ]
        },
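        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "OmegaConf treats `???` as a mandatory missing value: `OmegaConf.is_missing` reports it, and reading it raises `MissingMandatoryValue` until a real value is assigned. A minimal sketch on a toy config (not the MatchboxNet config itself):\n",
                "\n",
                "```python\n",
                "from omegaconf import OmegaConf\n",
                "from omegaconf.errors import MissingMandatoryValue\n",
                "\n",
                "# Toy config mirroring the shape of the `train_ds` section, with a placeholder path\n",
                "cfg = OmegaConf.create({'train_ds': {'manifest_filepath': '???'}})\n",
                "print(OmegaConf.is_missing(cfg.train_ds, 'manifest_filepath'))  # True\n",
                "\n",
                "try:\n",
                "    _ = cfg.train_ds.manifest_filepath  # reading a '???' value raises an error\n",
                "except MissingMandatoryValue:\n",
                "    print('manifest_filepath must be set before use')\n",
                "\n",
                "cfg.train_ds.manifest_filepath = 'train_manifest.json'\n",
                "print(OmegaConf.is_missing(cfg.train_ds, 'manifest_filepath'))  # False\n",
                "```"
            ]
        },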
        {
            "cell_type": "code",
            "metadata": {
                "id": "m181HXev8T97"
            },
            "source": [
                "config.model.train_ds.manifest_filepath = train_dataset\n",
                "config.model.validation_ds.manifest_filepath = val_dataset\n",
                "config.model.test_ds.manifest_filepath = test_dataset"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "pbXngoCM5IRG"
            },
            "source": [
                "## Building the PyTorch Lightning Trainer\n",
                "\n",
                "NeMo models are primarily PyTorch Lightning modules - and therefore are entirely compatible with the PyTorch Lightning ecosystem!\n",
                "\n",
                "Lets first instantiate a Trainer object!"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "bYtvdBlG5afU"
            },
            "source": [
                "import torch\n",
                "import lightning.pytorch as pl"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "jRN18CdH51nN"
            },
            "source": [
                "print(\"Trainer config - \\n\")\n",
                "print(OmegaConf.to_yaml(config.trainer))"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "gHf6cHvm6H9b"
            },
            "source": [
                "# Lets modify some trainer configs for this demo\n",
                "# Checks if we have GPU available and uses it\n",
                "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n",
                "config.trainer.devices = 1\n",
                "config.trainer.accelerator = accelerator\n",
                "\n",
                "# Reduces maximum number of epochs to 5 for quick demonstration\n",
                "config.trainer.max_epochs = 5\n",
                "\n",
                "# Remove distributed training flags\n",
                "config.trainer.strategy = 'auto'"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "UB9nr7G56G3L"
            },
            "source": [
                "trainer = pl.Trainer(**config.trainer)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "2wt603Vq6sqX"
            },
            "source": [
                "## Setting up a NeMo Experiment\n",
                "\n",
                "NeMo has an experiment manager that handles logging and checkpointing for us, so let's use it ! "
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "TfWJFg7p6Ezf"
            },
            "source": [
                "from nemo.utils.exp_manager import exp_manager"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "SC-QPoW44-p2"
            },
            "source": [
                "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "Yqi6rkNR7Dph"
            },
            "source": [
                "# The exp_dir provides a path to the current experiment for easy access\n",
                "exp_dir = str(exp_dir)\n",
                "exp_dir"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "t0zz-vHH7Uuh"
            },
            "source": [
                "## Building the MatchboxNet Model\n",
                "\n",
                "MatchboxNet is an ASR model with a classification task - it generates one label for the entire provided audio stream. Therefore we encapsulate it inside the `EncDecClassificationModel` as follows."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "FRMrKhyf5vhy"
            },
            "source": [
                "asr_model = nemo_asr.models.EncDecClassificationModel(cfg=config.model, trainer=trainer)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "jA9UND-Q_oyw"
            },
            "source": [
                "# Training a MatchboxNet Model\n",
                "\n",
                "As MatchboxNet is inherently a PyTorch Lightning Model, it can easily be trained in a single line - `trainer.fit(model)` !"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "3ngKcRFqBfIF"
            },
            "source": [
                "### Monitoring training progress\n",
                "\n",
                "Before we begin training, let's first create a Tensorboard visualization to monitor progress\n"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "sT3371CbJ8Rz"
            },
            "source": [
                "try:\n",
                "  from google import colab\n",
                "  COLAB_ENV = True\n",
                "except (ImportError, ModuleNotFoundError):\n",
                "  COLAB_ENV = False"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "Cyfec0PDBsXa"
            },
            "source": [
                "# Load the TensorBoard notebook extension\n",
                "if COLAB_ENV:\n",
                "  %load_ext tensorboard\n",
                "else:\n",
                "  print(\"To use tensorboard, please use this notebook in a Google Colab environment.\")"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "4L5ymu-QBxmz"
            },
            "source": [
                "if COLAB_ENV:\n",
                "  %tensorboard --logdir {exp_dir}\n",
                "else:\n",
                "  print(\"To use tensorboard, please use this notebook in a Google Colab environment.\")"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "ZApuELDIKQgC"
            },
            "source": [
                "### Training for 5 epochs\n",
                "We see below that the model begins to get modest scores on the validation set after just 5 epochs of training"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "9xiUUJlH5KdD"
            },
            "source": [
                "trainer.fit(asr_model)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "Dkds1jSvKgSc"
            },
            "source": [
                "### Evaluation on the Test set\n",
                "\n",
                "Lets compute the final score on the test set via `trainer.test(model)`"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "mULTrhEJ_6wV"
            },
            "source": [
                "trainer.test(asr_model, ckpt_path=None)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "XQntce8cLiUC"
            },
            "source": [
                "# Fast Training\n",
                "\n",
                "We can dramatically improve the time taken to train this model by using Multi GPU training along with Mixed Precision.\n",
                "\n",
                "```python\n",
                "# Trainer with a distributed backend:\n",
                "trainer = Trainer(devices=2, num_nodes=2, accelerator='gpu', strategy='auto')\n",
                "\n",
                "# Mixed precision:\n",
                "trainer = Trainer(amp_level='O1', precision=16)\n",
                "\n",
                "# Of course, you can combine these flags as well.\n",
                "```"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "ifDHkunjM8y6"
            },
            "source": [
                "# Evaluation of incorrectly predicted samples\n",
                "\n",
                "Given that we have a trained model, which performs reasonably well, let's try to listen to the samples where the model is least confident in its predictions.\n",
                "\n",
                "For this, we need the support of the librosa library.\n",
                "\n",
                "**NOTE**: The following code depends on librosa. To install it, run the following code block first."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "s3w3LhHcKuD2"
            },
            "source": [
                "!pip install librosa"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "PcJrZ72sNCkM"
            },
            "source": [
                "## Extract the predictions from the model\n",
                "\n",
                "We want to possess the actual logits of the model instead of just the final evaluation score, so we can define a function to perform the forward step for us without computing the final loss. Instead, we extract the logits per batch of samples provided."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "rvxdviYtOFjK"
            },
            "source": [
                "## Accessing the data loaders\n",
                "\n",
                "We can utilize the `setup_test_data` method in order to instantiate a data loader for the dataset we want to analyze.\n",
                "\n",
                "For convenience, we can access these instantiated data loaders using the following accessors - `asr_model._train_dl`, `asr_model._validation_dl` and `asr_model._test_dl`."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "CB0QZCAmM656"
            },
            "source": [
                "asr_model.setup_test_data(config.model.test_ds)\n",
                "test_dl = asr_model._test_dl"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "rA7gXawcPoip"
            },
            "source": [
                "## Partial Test Step\n",
                "\n",
                "Below we define a utility function to perform most of the test step. For reference, the test step is defined as follows:\n",
                "\n",
                "```python\n",
                "    def test_step(self, batch, batch_idx, dataloader_idx=0):\n",
                "        audio_signal, audio_signal_len, labels, labels_len = batch\n",
                "        logits = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)\n",
                "        loss_value = self.loss(logits=logits, labels=labels)\n",
                "        correct_counts, total_counts = self._accuracy(logits=logits, labels=labels)\n",
                "        return {'test_loss': loss_value, 'test_correct_counts': correct_counts, 'test_total_counts': total_counts}\n",
                "```"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "sBsDOm5ROpQI"
            },
            "source": [
                "@torch.no_grad()\n",
                "def extract_logits(model, dataloader):\n",
                "  logits_buffer = []\n",
                "  label_buffer = []\n",
                "\n",
                "  # Follow the above definition of the test_step\n",
                "  for batch in dataloader:\n",
                "    audio_signal, audio_signal_len, labels, labels_len = batch\n",
                "    logits = model(input_signal=audio_signal, input_signal_length=audio_signal_len)\n",
                "\n",
                "    logits_buffer.append(logits)\n",
                "    label_buffer.append(labels)\n",
                "    print(\".\", end='')\n",
                "  print()\n",
                "  \n",
                "  print(\"Finished extracting logits !\")\n",
                "  logits = torch.cat(logits_buffer, 0)\n",
                "  labels = torch.cat(label_buffer, 0)\n",
                "  return logits, labels\n"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "mZSdprUlOuoV"
            },
            "source": [
                "cpu_model = asr_model.cpu()\n",
                "cpu_model.eval()\n",
                "logits, labels = extract_logits(cpu_model, test_dl)\n",
                "print(\"Logits:\", logits.shape, \"Labels :\", labels.shape)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "9Wd0ukgNXRBz"
            },
            "source": [
                "# Compute accuracy - `_accuracy` is a PyTorch Lightning Metric !\n",
                "acc = cpu_model._accuracy(logits=logits, labels=labels)\n",
                "print(\"Accuracy : \", float(acc[0]*100))"
            ],
            "execution_count": null,
            "outputs": []
        },
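        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "As a sanity check on the metric, top-1 accuracy can also be computed directly by taking the arg-max over the class dimension and comparing it with the labels. The hypothetical `top1_accuracy` helper below runs on small made-up tensors so that it is self-contained; in this notebook, you would pass the `logits` and `labels` returned by `extract_logits`.\n",
                "\n",
                "```python\n",
                "import torch\n",
                "\n",
                "\n",
                "def top1_accuracy(logits, labels):\n",
                "    # Predicted class = index of the largest logit for each sample\n",
                "    preds = logits.argmax(dim=-1)\n",
                "    return (preds == labels).float().mean().item()\n",
                "\n",
                "\n",
                "# Made-up logits for 4 samples and 3 classes\n",
                "demo_logits = torch.tensor([[2.0, 0.1, 0.3],\n",
                "                            [0.2, 1.5, 0.1],\n",
                "                            [0.3, 0.2, 0.9],\n",
                "                            [1.2, 0.4, 0.8]])\n",
                "demo_labels = torch.tensor([0, 1, 0, 0])\n",
                "print(top1_accuracy(demo_logits, demo_labels))  # 0.75\n",
                "```"
            ]
        },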
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "NwN9OSqCauSH"
            },
            "source": [
                "## Filtering out incorrect samples\n",
                "Let us now filter out the incorrectly labeled samples from the total set of samples in the test set"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "N1YJvsmcZ0uE"
            },
            "source": [
                "import librosa\n",
                "import json\n",
                "import IPython.display as ipd"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "jZAT9yGAayvR"
            },
            "source": [
                "# First let's create a utility class to remap the integer class labels to actual string label\n",
                "class ReverseMapLabel:\n",
                "    def __init__(self, data_loader):\n",
                "        self.label2id = dict(data_loader.dataset.label2id)\n",
                "        self.id2label = dict(data_loader.dataset.id2label)\n",
                "\n",
                "    def __call__(self, pred_idx, label_idx):\n",
                "        return self.id2label[pred_idx], self.id2label[label_idx]"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "X3GSXvYHa4KJ"
            },
            "source": [
                "# Next, let's get the indices of all the incorrectly labeled samples\n",
                "sample_idx = 0\n",
                "incorrect_preds = []\n",
                "rev_map = ReverseMapLabel(test_dl)\n",
                "\n",
                "# Remember, `extract_logits` returned (logits, labels) for the whole test set\n",
                "probs = torch.softmax(logits, dim=-1)\n",
                "probas, preds = torch.max(probs, dim=-1)\n",
                "\n",
                "total_count = cpu_model._accuracy.total_counts_k[0]\n",
                "incorrect_ids = (preds != labels).nonzero()\n",
                "for idx in incorrect_ids:\n",
                "    proba = float(probas[idx][0])\n",
                "    pred = int(preds[idx][0])\n",
                "    label = int(labels[idx][0])\n",
                "    idx = int(idx[0]) + sample_idx\n",
                "\n",
                "    incorrect_preds.append((idx, *rev_map(pred, label), proba))\n",
                "\n",
                "print(f\"Num test samples : {total_count.item()}\")\n",
                "print(f\"Num errors : {len(incorrect_preds)}\")\n",
                "\n",
                "# Sort by prediction confidence, lowest confidence first\n",
                "incorrect_preds = sorted(incorrect_preds, key=lambda x: x[-1], reverse=False)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "0JgGo71gcDtD"
            },
            "source": [
                "## Examine a subset of incorrect samples\n",
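                "Let's print out the (test id, predicted label, ground truth label, confidence) tuple for the first 20 misclassified samples. Since the list is sorted in ascending order of confidence, these are the mistakes the model was least sure about."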
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "x37wNJsNbcw0"
            },
            "source": [
                "for incorrect_sample in incorrect_preds[:20]:\n",
                "    print(str(incorrect_sample))"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "tDnwYsDKcLv9"
            },
            "source": [
                "## Define a threshold at or below which we designate a model's prediction as \"low confidence\""
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "dpvzeh4PcGJs"
            },
            "source": [
                "# Count how many misclassified samples fall at or below this threshold\n",
                "low_confidence_threshold = 0.25\n",
                "count_low_confidence = len(list(filter(lambda x: x[-1] <= low_confidence_threshold, incorrect_preds)))\n",
                "print(f\"Number of low confidence predictions : {count_low_confidence}\")"
            ],
            "execution_count": null,
            "outputs": []
        },
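        {
            "cell_type": "markdown",
            "metadata": {
                "id": "thrSwp4Lm9Qa"
            },
            "source": [
                "The value `0.25` is a judgment call. One way to choose it is to sweep a few thresholds and count how many misclassified samples fall at or below each one. The sketch below uses made-up `(test id, predicted label, ground truth label, confidence)` tuples in the same layout as `incorrect_preds`; the helper `count_at_or_below` is hypothetical.\n",
                "\n",
                "```python\n",
                "def count_at_or_below(preds, threshold):\n",
                "    # Confidence is the last element of each (idx, pred, label, proba) tuple\n",
                "    return sum(1 for p in preds if p[-1] <= threshold)\n",
                "\n",
                "# Made-up examples in the same layout as `incorrect_preds`\n",
                "example_preds = [\n",
                "    (4, 'no', 'go', 0.12),\n",
                "    (9, 'off', 'up', 0.31),\n",
                "    (15, 'yes', 'left', 0.58),\n",
                "    (21, 'down', 'no', 0.22),\n",
                "]\n",
                "\n",
                "for threshold in (0.25, 0.5, 0.75):\n",
                "    print(threshold, count_at_or_below(example_preds, threshold))\n",
                "```\n",
                "\n",
                "On the notebook's data, `count_at_or_below(incorrect_preds, low_confidence_threshold)` computes the same number as `count_low_confidence`."
            ]
        },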
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "ERXyXvCAcSKR"
            },
            "source": [
                "## Let's listen to the samples the model is least confident about!"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "kxjNVjX8cPNP"
            },
            "source": [
                "# First let's create a helper function to parse the manifest files\n",
                "def parse_manifest(manifest):\n",
                "    data = []\n",
                "    for line in manifest:\n",
                "        line = json.loads(line)\n",
                "        data.append(line)\n",
                "\n",
                "    return data"
            ],
            "execution_count": null,
            "outputs": []
        },
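        {
            "cell_type": "markdown",
            "metadata": {
                "id": "mnfPrs8Tq3Wz"
            },
            "source": [
                "Each line of a NeMo manifest is a standalone JSON object (the JSON Lines format), which is why `parse_manifest` calls `json.loads` once per line. The variant below, a hypothetical `parse_manifest_lines`, also skips blank lines (for example a stray empty line at the end of the file), which `json.loads` would otherwise reject. The `duration` key and the file paths in the example records are illustrative assumptions; the notebook itself only relies on `audio_filepath` and `command`.\n",
                "\n",
                "```python\n",
                "import json\n",
                "\n",
                "def parse_manifest_lines(lines):\n",
                "    # Parse one JSON object per line, ignoring blank lines\n",
                "    return [json.loads(line) for line in lines if line.strip()]\n",
                "\n",
                "# Illustrative records; real entries point to .wav files on disk\n",
                "records = [\n",
                "    {'audio_filepath': 'yes/sample_0.wav', 'duration': 1.0, 'command': 'yes'},\n",
                "    {'audio_filepath': 'no/sample_1.wav', 'duration': 1.0, 'command': 'no'},\n",
                "]\n",
                "# Serialize each record as one line and append a trailing blank line\n",
                "lines = [json.dumps(r) for r in records] + ['']\n",
                "\n",
                "parsed = parse_manifest_lines(lines)\n",
                "print(len(parsed), parsed[0]['command'])\n",
                "```"
            ]
        },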
        {
            "cell_type": "code",
            "metadata": {
                "id": "IWxqw5k-cUVd"
            },
            "source": [
                "# Next, let's create a helper function to actually listen to certain samples\n",
                "def listen_to_file(sample_id, pred=None, label=None, proba=None):\n",
                "    # Load the audio waveform using librosa\n",
                "    filepath = test_samples[sample_id]['audio_filepath']\n",
                "    audio, sample_rate = librosa.load(filepath)\n",
                "\n",
                "    if pred is not None and label is not None and proba is not None:\n",
                "        print(f\"Sample : {sample_id} Prediction : {pred} Label : {label} Confidence = {proba: 0.4f}\")\n",
                "    else:\n",
                "        print(f\"Sample : {sample_id}\")\n",
                "\n",
                "    return ipd.Audio(audio, rate=sample_rate)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "HPj1tFNIcXaU"
            },
            "source": [
                "# Now let's load the test manifest into memory\n",
                "test_samples = []\n",
                "with open(test_dataset, 'r') as test_f:\n",
                "    test_samples = test_f.readlines()\n",
                "\n",
                "test_samples = parse_manifest(test_samples)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "Nt7b_uiScZcC"
            },
            "source": [
                "# Finally, let's listen to the low-confidence samples where the model made a mistake\n",
                "# Note: This list of incorrect samples may be quite large, so you may choose to subsample `incorrect_preds`\n",
                "count = min(count_low_confidence, 20)  # replace this line with just `count_low_confidence` to listen to all samples with low confidence\n",
                "\n",
                "for sample_id, pred, label, proba in incorrect_preds[:count]:\n",
                "    ipd.display(listen_to_file(sample_id, pred=pred, label=label, proba=proba))"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "gxLGGDvHW2kV"
            },
            "source": [
                "# Fine-tuning on a new dataset\n",
                "\n",
                "So far, we have trained our model on all 30/35 classes of the Google Speech Commands dataset (v1/v2).\n",
                "\n",
                "We will now demonstrate fine-tuning a trained model on a subset of these classes.\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "mZAPGTzeXnuQ"
            },
            "source": [
                "## Preparing the data-subsets\n",
                "\n",
                "Let's select two of the classes, `yes` and `no`, and prepare manifests that contain only these samples."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "G1RI4GBNfjUW"
            },
            "source": [
                "import json"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "L3cFvN5vcbjb"
            },
            "source": [
                "def extract_subset_from_manifest(name: str, manifest_path: str, labels: list):\n",
                "  manifest_dir = os.path.split(manifest_path)[0]\n",
                "  labels = set(labels)\n",
                "  manifest_values = []\n",
                "\n",
                "  print(f\"Parsing manifest: {manifest_path}\")\n",
                "  with open(manifest_path, 'r') as f:\n",
                "    for line in f:\n",
                "      val = json.loads(line)\n",
                "\n",
                "      if val['command'] in labels:\n",
                "        manifest_values.append(val)\n",
                "\n",
                "  print(f\"Number of files extracted from dataset: {len(manifest_values)}\")\n",
                "\n",
                "  outpath = os.path.join(manifest_dir, name)\n",
                "  with open(outpath, 'w') as f:\n",
                "    for val in manifest_values:\n",
                "      json.dump(val, f)\n",
                "      f.write(\"\\n\")\n",
                "      f.flush()\n",
                "\n",
                "  print(\"Manifest subset written to path :\", outpath)\n",
                "  print()\n",
                "\n",
                "  return outpath"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "fXQ0N1evfqZ8"
            },
            "source": [
                "labels = [\"yes\", \"no\"]\n",
                "\n",
                "train_subdataset = extract_subset_from_manifest(\"train_subset.json\", train_dataset, labels)\n",
                "val_subdataset = extract_subset_from_manifest(\"val_subset.json\", val_dataset, labels)\n",
                "test_subdataset = extract_subset_from_manifest(\"test_subset.json\", test_dataset, labels)"
            ],
            "execution_count": null,
            "outputs": []
        },
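        {
            "cell_type": "markdown",
            "metadata": {
                "id": "sbsCnt5Rv1Jd"
            },
            "source": [
                "After writing the subsets, a quick sanity check is to count how many samples of each class ended up in every file. The hypothetical helper `count_commands` below tallies the `command` field of a manifest with `collections.Counter`; the demo writes a throwaway manifest so that the sketch runs on its own.\n",
                "\n",
                "```python\n",
                "import json\n",
                "import os\n",
                "import tempfile\n",
                "from collections import Counter\n",
                "\n",
                "def count_commands(manifest_path):\n",
                "    # Tally the 'command' field of every non-empty manifest line\n",
                "    with open(manifest_path, 'r') as f:\n",
                "        return Counter(json.loads(line)['command'] for line in f if line.strip())\n",
                "\n",
                "# Self-contained demo on a throwaway manifest\n",
                "with tempfile.TemporaryDirectory() as tmp_dir:\n",
                "    demo_path = os.path.join(tmp_dir, 'demo_subset.json')\n",
                "    with open(demo_path, 'w') as f:\n",
                "        for command in ['yes', 'no', 'yes']:\n",
                "            print(json.dumps({'audio_filepath': command + '.wav', 'command': command}), file=f)\n",
                "    demo_counts = count_commands(demo_path)\n",
                "\n",
                "print(demo_counts)\n",
                "```\n",
                "\n",
                "With the notebook's data, it could be called as `count_commands(train_subdataset)`, and the result should contain only the `yes` and `no` keys."
            ]
        },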
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "IO5pVNyKimiE"
            },
            "source": [
                "## Saving/Restoring a checkpoint\n",
                "\n",
                "There are multiple ways to save and load models in NeMo. Since all NeMo models are inherently Lightning Modules, we can use the standard way that PyTorch Lightning saves and restores models.\n",
                "\n",
                "NeMo also provides a more advanced model save/restore format, which encapsulates all the parts of the model that are required to restore that model for immediate use.\n",
                "\n",
                "In this example, we will explore both ways of saving and restoring models, but we will focus on the PyTorch Lightning method."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "lMKvrT88jZwC"
            },
            "source": [
                "### Saving and Restoring via PyTorch Lightning Checkpoints\n",
                "\n",
                "When using NeMo for training, it is advisable to utilize the `exp_manager` framework. It handles checkpointing and logging (TensorBoard, and optionally Weights & Biases), as well as multi-node and multi-GPU logging.\n",
                "\n",
                "Since we used the `exp_manager` framework above, we have access to the directory where the checkpoints are stored.\n",
                "\n",
                "With the default settings, `exp_manager` saves multiple checkpoints for us:\n",
                "\n",
                "1) A few checkpoints from certain steps of training. They will have `--val_loss=` tags.\n",
                "\n",
                "2) A checkpoint at the last epoch of training, denoted by `-last`.\n",
                "\n",
                "3) If the model finishes training, it will also have a `--end` checkpoint."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "TcHTw5ErmQRi"
            },
            "source": [
                "import glob"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "5h8zMJHngUrV"
            },
            "source": [
                "print(exp_dir)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "F9K_Ct_hl8oU"
            },
            "source": [
                "# Let's list all the checkpoints we have\n",
                "checkpoint_dir = os.path.join(exp_dir, 'checkpoints')\n",
                "checkpoint_paths = list(glob.glob(os.path.join(checkpoint_dir, \"*.ckpt\")))\n",
                "checkpoint_paths"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "67fbB61umfb4"
            },
            "source": [
                "# We want the checkpoint saved after the final step of training\n",
                "final_checkpoint = list(filter(lambda x: \"-last.ckpt\" in x, checkpoint_paths))[0]\n",
                "print(final_checkpoint)"
            ],
            "execution_count": null,
            "outputs": []
        },
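        {
            "cell_type": "markdown",
            "metadata": {
                "id": "ckptFnd6Hy2P"
            },
            "source": [
                "The filter above raises an `IndexError` if no file name contains `-last.ckpt`, for example when training was interrupted before such a checkpoint was written. A slightly more defensive variant, sketched below as the hypothetical helper `find_checkpoint`, prefers the `-last` checkpoint and otherwise falls back to the most recently modified `.ckpt` file, returning `None` if the directory holds no checkpoints. The placeholder file names in the demo are illustrative only.\n",
                "\n",
                "```python\n",
                "import glob\n",
                "import os\n",
                "import tempfile\n",
                "\n",
                "def find_checkpoint(checkpoint_dir, suffix='-last.ckpt'):\n",
                "    paths = glob.glob(os.path.join(checkpoint_dir, '*.ckpt'))\n",
                "    # Prefer a checkpoint whose file name ends with `suffix`\n",
                "    preferred = [p for p in paths if p.endswith(suffix)]\n",
                "    if preferred:\n",
                "        return preferred[0]\n",
                "    # Otherwise fall back to the newest checkpoint (None if there are none)\n",
                "    return max(paths, key=os.path.getmtime, default=None)\n",
                "\n",
                "# Demo on a temporary directory holding empty placeholder files\n",
                "with tempfile.TemporaryDirectory() as tmp_dir:\n",
                "    for name in ['model--val_loss=0.52-epoch=3.ckpt', 'model-last.ckpt']:\n",
                "        open(os.path.join(tmp_dir, name), 'w').close()\n",
                "    demo_choice = os.path.basename(find_checkpoint(tmp_dir))\n",
                "\n",
                "print(demo_choice)\n",
                "```\n",
                "\n",
                "In the notebook, `find_checkpoint(checkpoint_dir)` would return the same path as `final_checkpoint` whenever a `-last` checkpoint exists."
            ]
        },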
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "ZADUzv02nknZ"
            },
            "source": [
                "### Restoring from a PyTorch Lightning checkpoint\n",
                "\n",
                "To restore a model, use the `LightningModule.load_from_checkpoint()` class method."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "ywd9Qj4Xm3VC"
            },
            "source": [
                "restored_model = nemo_asr.models.EncDecClassificationModel.load_from_checkpoint(final_checkpoint)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "0f4GQa8vB1BB"
            },
            "source": [
                "## Prepare the model for fine-tuning\n",
                "\n",
                "Remember, the original model was trained for a 30/35-way classification task. Now we require only a subset of these classes, so we need to modify the decoder head to support fewer classes.\n",
                "\n",
                "We can do this easily with the convenient function `EncDecClassificationModel.change_labels(new_label_list)`.\n",
                "\n",
                "This step discards the old decoder head but preserves the encoder!"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "iMCMds7pB16U"
            },
            "source": [
                "restored_model.change_labels(labels)"
            ],
            "execution_count": null,
            "outputs": []
        },
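        {
            "cell_type": "markdown",
            "metadata": {
                "id": "hdSwap3Nb8Kc"
            },
            "source": [
                "Conceptually, replacing the decoder head means keeping the trained encoder weights and attaching a freshly initialized output layer with one unit per new class. The plain-PyTorch sketch below illustrates the idea on a toy model; it is not NeMo's implementation of `change_labels()`, and the `ToyClassifier` class and its layer sizes are hypothetical.\n",
                "\n",
                "```python\n",
                "import torch\n",
                "import torch.nn as nn\n",
                "\n",
                "class ToyClassifier(nn.Module):\n",
                "    def __init__(self, num_features, num_classes):\n",
                "        super().__init__()\n",
                "        self.encoder = nn.Sequential(nn.Linear(num_features, 16), nn.ReLU())\n",
                "        self.decoder = nn.Linear(16, num_classes)\n",
                "\n",
                "    def forward(self, x):\n",
                "        return self.decoder(self.encoder(x))\n",
                "\n",
                "model = ToyClassifier(num_features=8, num_classes=35)\n",
                "encoder_weight_before = model.encoder[0].weight.detach().clone()\n",
                "\n",
                "# Replace the 35-way head with a freshly initialized 2-way head; the encoder is untouched\n",
                "model.decoder = nn.Linear(model.decoder.in_features, 2)\n",
                "\n",
                "out = model(torch.randn(4, 8))\n",
                "print(out.shape)\n",
                "print(torch.equal(model.encoder[0].weight, encoder_weight_before))\n",
                "```\n",
                "\n",
                "The output now has two values per sample, while the encoder weights are identical to their values before the swap."
            ]
        },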
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "rrspQ2QFtbCK"
            },
            "source": [
                "### Prepare the data loaders\n",
                "\n",
                "Upon restoration, the model will not attempt to set up any data loaders.\n",
                "\n",
                "This lets us manually set up whichever datasets we want: train and val to fine-tune the model, test to only evaluate it, or all three to do both!\n",
                "\n",
                "The entire config that we used before can still be accessed via `ModelPT.cfg`, so we will use it to set up our data loaders. This also gives us the opportunity to set any additional parameters we wish!"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "9JxhiZN5ulUl"
            },
            "source": [
                "import copy"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "qzHfTOkPowJo"
            },
            "source": [
                "train_subdataset_cfg = copy.deepcopy(restored_model.cfg.train_ds)\n",
                "val_subdataset_cfg = copy.deepcopy(restored_model.cfg.validation_ds)\n",
                "test_subdataset_cfg = copy.deepcopy(restored_model.cfg.test_ds)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "it9-vFX6vHUl"
            },
            "source": [
                "# Set the paths to the subset of the dataset\n",
                "train_subdataset_cfg.manifest_filepath = train_subdataset\n",
                "val_subdataset_cfg.manifest_filepath = val_subdataset\n",
                "test_subdataset_cfg.manifest_filepath = test_subdataset"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "1qzWY8QDvgfc"
            },
            "source": [
                "# Set up the data loaders for the restored model\n",
                "restored_model.setup_training_data(train_subdataset_cfg)\n",
                "restored_model.setup_multiple_validation_data(val_subdataset_cfg)\n",
                "restored_model.setup_multiple_test_data(test_subdataset_cfg)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "y8GZ5a5rC0gY"
            },
            "source": [
                "# Check data loaders are correct\n",
                "print(\"Train dataset labels :\", restored_model._train_dl.dataset.labels)\n",
                "print(\"Val dataset labels :\", restored_model._validation_dl.dataset.labels)\n",
                "print(\"Test dataset labels :\", restored_model._test_dl.dataset.labels)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "76yDcWZ9zl2G"
            },
            "source": [
                "## Setting up a new Trainer and Experiment Manager\n",
                "\n",
                "A restored model has a utility method to attach the Trainer object to it, which is necessary in order to correctly set up the optimizer and scheduler!\n",
                "\n",
                "**Note**: The restored model does not include the trainer config. It is necessary to create a new Trainer object suitable for the environment where the model is being trained. The template can be replicated from any of the training scripts.\n",
                "\n",
                "Here, since we already had the previous config object that prepared the trainer, we could have used it, but for demonstration, we will set up the trainer config manually."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "swTe3WvBzkBJ"
            },
            "source": [
                "# Setup the new trainer object\n",
                "# Let's modify some trainer configs for this demo\n",
                "# Use a GPU if one is available, otherwise fall back to the CPU\n",
                "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n",
                "\n",
                "trainer_config = OmegaConf.create(dict(\n",
                "    devices=1,\n",
                "    accelerator=accelerator,\n",
                "    max_epochs=5,\n",
                "    max_steps=-1,  # -1 disables the step limit, so max_epochs controls training length\n",
                "    num_nodes=1,\n",
                "    accumulate_grad_batches=1,\n",
                "    enable_checkpointing=False,  # Provided by exp_manager\n",
                "    logger=False,  # Provided by exp_manager\n",
                "    log_every_n_steps=1,  # Interval of logging.\n",
                "    val_check_interval=1.0,  # Set to 0.25 to check 4 times per epoch, or an int for number of iterations\n",
                "))\n",
                "print(OmegaConf.to_yaml(trainer_config))"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "Nd_ej4bI3TIy"
            },
            "source": [
                "trainer_finetune = pl.Trainer(**trainer_config)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "WtGu5q5T32XA"
            },
            "source": [
                "### Setting the trainer to the restored model\n",
                "\n",
                "All NeMo models provide a convenience method, `set_trainer()`, to set up the trainer after restoration."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "BTozhedA3zpM"
            },
            "source": [
                "restored_model.set_trainer(trainer_finetune)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "XojTpEiI3TQa"
            },
            "source": [
                "exp_dir_finetune = exp_manager(trainer_finetune, config.get(\"exp_manager\", None))"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "x_LSbmCQ3TUf"
            },
            "source": [
                "exp_dir_finetune = str(exp_dir_finetune)\n",
                "exp_dir_finetune"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "QT_mWWnSxPLv"
            },
            "source": [
                "## Setup optimizer + scheduler\n",
                "\n",
                "For a fine-tuning experiment, let's set up the optimizer and scheduler!\n",
                "\n",
                "We will use a much lower learning rate than before, and also swap out the scheduler from PolyHoldDecay to CosineAnnealing."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "TugHsePsxA5Q"
            },
            "source": [
                "optim_sched_cfg = copy.deepcopy(restored_model.cfg.optim)\n",
                "# Struct mode prevents us from popping off elements from the config, so let's disable it\n",
                "OmegaConf.set_struct(optim_sched_cfg, False)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "pZSo0sWPxwiG"
            },
            "source": [
                "# Let's set the maximum learning rate to the previous minimum learning rate\n",
                "optim_sched_cfg.lr = 0.001\n",
                "\n",
                "# Let's change the scheduler\n",
                "optim_sched_cfg.sched.name = \"CosineAnnealing\"\n",
                "\n",
                "# \"power\" isn't applicable to CosineAnnealing, so let's remove it\n",
                "optim_sched_cfg.sched.pop('power')\n",
                "\n",
                "# \"hold_ratio\" isn't applicable to CosineAnnealing, so let's remove it\n",
                "optim_sched_cfg.sched.pop('hold_ratio')\n",
                "\n",
                "# Set \"min_lr\" to a lower value\n",
                "optim_sched_cfg.sched.min_lr = 1e-4\n",
                "\n",
                "print(OmegaConf.to_yaml(optim_sched_cfg))"
            ],
            "execution_count": null,
            "outputs": []
        },
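        {
            "cell_type": "markdown",
            "metadata": {
                "id": "cosLr9Fw4UeT"
            },
            "source": [
                "For intuition about the new schedule, the textbook cosine-annealing curve decays the learning rate from `lr` to `min_lr` over `T` steps as `min_lr + 0.5 * (lr - min_lr) * (1 + cos(pi * t / T))`. The sketch below evaluates that formula with the values set above (`lr = 0.001`, `min_lr = 1e-4`). NeMo's `CosineAnnealing` scheduler may add warmup or other details, so treat this as an approximation of its behaviour rather than its exact implementation; `T = 100` is an arbitrary illustrative value.\n",
                "\n",
                "```python\n",
                "import math\n",
                "\n",
                "def cosine_annealing_lr(step, total_steps, max_lr=0.001, min_lr=1e-4):\n",
                "    # Textbook cosine decay from max_lr at step 0 to min_lr at step == total_steps\n",
                "    progress = min(step, total_steps) / total_steps\n",
                "    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))\n",
                "\n",
                "# T = 100 is an arbitrary illustrative number of steps\n",
                "for step in (0, 25, 50, 75, 100):\n",
                "    print(step, round(cosine_annealing_lr(step, total_steps=100), 6))\n",
                "```\n",
                "\n",
                "The printed values start at `lr`, sit halfway between `lr` and `min_lr` at the midpoint, and end at `min_lr`."
            ]
        },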
        {
            "cell_type": "code",
            "metadata": {
                "id": "FqqyFF3Ey5If"
            },
            "source": [
                "# Now let's update the optimizer settings\n",
                "restored_model.setup_optimization(optim_sched_cfg)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "mdivgIPUzgP_"
            },
            "source": [
                "# We can also directly replace the config in place if we choose to\n",
                "restored_model.cfg.optim = optim_sched_cfg"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "3-lRyz2_Eyrl"
            },
            "source": [
                "## Fine-tune training step\n",
                "\n",
                "We fine-tune on the subset classification problem. Note that the model was originally trained on these classes (the subset defined here was already part of the training data above).\n",
                "\n",
                "When fine-tuning on a truly new dataset, we will not see such a dramatic improvement in performance. However, the model should still converge somewhat faster than if it were trained from scratch."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "nq-iHIgx6OId"
            },
            "source": [
                "### Monitor training progress via Tensorboard\n"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "PIacDWcD5vCR"
            },
            "source": [
                "if COLAB_ENV:\n",
                "  %tensorboard --logdir {exp_dir_finetune}\n",
                "else:\n",
                "  print(\"To use tensorboard, please use this notebook in a Google Colab environment.\")"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "r5_z1eW76fip"
            },
            "source": [
                "### Fine-tuning for 5 epochs"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "WH8rN6dA6V9S"
            },
            "source": [
                "trainer_finetune.fit(restored_model)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "lgV0s8auJpxV"
            },
            "source": [
                "### Evaluation on the Test set\n",
                "\n",
                "Let's compute the final score on the test set via `trainer.test(model)`"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "szpLp6XTDPaK"
            },
            "source": [
                "trainer_finetune.test(restored_model, ckpt_path=None)"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "uNBAaf1FKcAZ"
            },
            "source": [
                "## Advanced Usage: Exporting a model in its entirety\n",
                "\n",
                "While most models can be easily serialized via the Experiment Manager as a PyTorch Lightning checkpoint, there are certain models where this is insufficient. \n",
                "\n",
                "Consider the case where a Model contains artifacts such as tokenizers or other intermediate file objects that cannot be so easily serialized into a checkpoint.\n",
                "\n",
                "For such cases, NeMo offers two utility functions that enable serialization of a Model + artifacts - `save_to` and `restore_from`.\n",
                "\n",
                "Further details on these methods are available in the NeMo documentation."
            ]
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "Dov9g2j8Lyjs"
            },
            "source": [
                "import tarfile"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "WNixPPFNJyNc"
            },
            "source": [
                "# Save a model as a tarfile\n",
                "restored_model.save_to(os.path.join(exp_dir_finetune, \"model.nemo\"))"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "code",
            "metadata": {
                "id": "B2RHYNjjLrcW"
            },
            "source": [
                "# The above object is just a tarfile which can store additional artifacts.\n",
                "with tarfile.open(os.path.join(exp_dir_finetune, 'model.nemo')) as blob:\n",
                "  for item in blob:\n",
                "    print(item)"
            ],
            "execution_count": null,
            "outputs": []
        },
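        {
            "cell_type": "markdown",
            "metadata": {
                "id": "tarLst2Mq7Vd"
            },
            "source": [
                "If you only need the names of the files inside a `.nemo` archive, `tarfile.TarFile.getnames()` returns them without extracting anything. The sketch below builds a small stand-in archive so that it runs on its own. The helper `list_archive_members` is hypothetical, the member names `model_config.yaml` and `model_weights.ckpt` are illustrative assumptions, and the `'r:*'` mode lets `tarfile` detect whether the archive is compressed.\n",
                "\n",
                "```python\n",
                "import os\n",
                "import tarfile\n",
                "import tempfile\n",
                "\n",
                "def list_archive_members(archive_path):\n",
                "    # 'r:*' auto-detects compression (plain tar or gzip-compressed tar)\n",
                "    with tarfile.open(archive_path, 'r:*') as archive:\n",
                "        return archive.getnames()\n",
                "\n",
                "with tempfile.TemporaryDirectory() as tmp_dir:\n",
                "    # Build a stand-in archive from two empty placeholder files\n",
                "    archive_path = os.path.join(tmp_dir, 'demo.nemo')\n",
                "    with tarfile.open(archive_path, 'w:gz') as archive:\n",
                "        for name in ['model_config.yaml', 'model_weights.ckpt']:\n",
                "            file_path = os.path.join(tmp_dir, name)\n",
                "            open(file_path, 'w').close()\n",
                "            archive.add(file_path, arcname=name)\n",
                "    demo_members = list_archive_members(archive_path)\n",
                "\n",
                "print(demo_members)\n",
                "```\n",
                "\n",
                "Calling `list_archive_members(os.path.join(exp_dir_finetune, 'model.nemo'))` on the file saved above should list the same entries as the previous cell."
            ]
        },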
        {
            "cell_type": "code",
            "metadata": {
                "id": "fRo04x3TLxdu"
            },
            "source": [
                "# Restore a model from a tarfile\n",
                "restored_model_2 = nemo_asr.models.EncDecClassificationModel.restore_from(os.path.join(exp_dir_finetune, \"model.nemo\"))"
            ],
            "execution_count": null,
            "outputs": []
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "LyIegk2CPNsI"
            },
            "source": [
                "## Conclusion\n",
                "Once the model has been restored, either via a PyTorch Lightning checkpoint or via the `restore_from` method, you can fine-tune it by following the general steps above."
            ]
        }
    ]
}
